import java.util.Arrays;

/** Demonstrates {@link #findMiddleIndex(int[])} on a sample array. */
public class App {
    /**
     * Returns the leftmost index {@code i} at which the sum of the elements strictly to the left
     * of {@code i} equals the sum of the elements strictly to the right of {@code i}.
     *
     * <p>An empty side counts as a sum of zero. So index 0 matches when all later elements sum
     * to zero, and the last index matches when all earlier elements sum to zero.
     *
     * <p>Sums are accumulated in {@code long}. An {@code int} total can wrap around, and an
     * equality test on wrapped values would report a balance point that does not exist. A
     * {@code long} cannot overflow here because an array holds fewer than 2^31 elements, each
     * of magnitude at most 2^31.
     *
     * @param nums the values to scan; must not be {@code null}
     * @return the leftmost balancing index, or {@code -1} if there is none (including for an
     *     empty array)
     * @throws NullPointerException if {@code nums} is {@code null}
     */
    public static int findMiddleIndex(int[] nums) {
        // Widen before summing so the total cannot wrap around.
        final long total = Arrays.stream(nums).asLongStream().sum();
        long leftSum = 0;
        for (int i = 0; i < nums.length; i++) {
            // The right-hand sum is whatever remains after removing the left part and nums[i].
            final long rightSum = total - leftSum - nums[i];
            if (leftSum == rightSum) {
                return i;
            }
            leftSum += nums[i];
        }
        return -1;
    }

    /** Prints the middle index of a fixed sample array. */
    public static void main(String[] args) {
        int[] nums = {2, 3, -1, 8, 4};
        System.out.println(findMiddleIndex(nums));
    }
}
